{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1c94e59",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14679cb1",
   "metadata": {},
   "source": [
    "# Exploring the strengths and weaknesses of Stable Diffusion XL 0.9\n",
    "This notebook aims to uncover strengths and weaknesses of the Stable Diffusion XL 0.9 model by looking for groups of prompts whose pre-computed CLIP score is noticeably lower.\n",
    "\n",
    "**Note: it builds on [this notebook](stable_diffusion_evaluation.ipynb), which generates the necessary data. Run that notebook first to follow along, or use your own data.**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06650047",
   "metadata": {},
   "source": [
    "# Step 1: Loading the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913d879f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sliceguard import SliceGuard\n",
    "from renumics.spotlight import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f5a9669",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset\n",
    "df = pd.read_json(\"sd_dataset_scored_embedded_parti.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea409d30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-row text and image embeddings from the dataframe\n",
    "# into 2D arrays of shape (n_samples, embedding_dim)\n",
    "clip_text_embeddings = np.vstack(df[\"clip_text_embedding\"])\n",
    "clip_image_embeddings = np.vstack(df[\"clip_image_embedding\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f90d6001",
   "metadata": {},
   "source": [
    "# Step 2: Category-based analysis\n",
    "Check whether any **categories** yield worse results according to the pre-computed CLIP score."
   ]
  },
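  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95641b8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The CLIP score is already computed per sample, so the \"metric\" only has to\n",
    "# average it over a slice. The clip_score column is passed as both y and y_pred below,\n",
    "# which is why y_pred is ignored here.\n",
    "def return_precomputed_metric(y, y_pred):\n",
    "    return y.mean(0)"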
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a68317b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "sg = SliceGuard()\n",
    "\n",
    "# Explore which min_drop and min_support levels could make sense\n",
    "sg.show(df, [\"category\"],\n",
    "        \"clip_score\",\n",
    "        \"clip_score\",\n",
    "        return_precomputed_metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5742b2f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find category-specific issues\n",
    "issues = sg.find_issues(df, [\"category\"],\n",
    "                        \"clip_score\",\n",
    "                        \"clip_score\",\n",
    "                        return_precomputed_metric,\n",
    "                        min_support=50,\n",
    "                        min_drop=0.5)\n",
    "sg.report(spotlight_dtype={\"image\": Image})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86fa2370",
   "metadata": {},
   "source": [
    "# Step 3: Challenge-based analysis\n",
    "Check whether any **challenges** in image generation are problematic for Stable Diffusion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbb0a249",
   "metadata": {},
   "outputs": [],
   "source": [
    "sg = SliceGuard()\n",
    "\n",
    "# Explore which min_drop and min_support levels could make sense\n",
    "sg.show(df, [\"challenge\"],\n",
    "        \"clip_score\",\n",
    "        \"clip_score\",\n",
    "        return_precomputed_metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "943c99b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find challenge-specific issues\n",
    "sg = SliceGuard()\n",
    "\n",
    "sg.find_issues(df, [\"challenge\"],\n",
    "               \"clip_score\",\n",
    "               \"clip_score\",\n",
    "               return_precomputed_metric,\n",
    "               min_drop=1,\n",
    "               min_support=20)\n",
    "sg.report(spotlight_dtype={\"image\": Image})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c174a866",
   "metadata": {},
   "source": [
    "# Step 4: Challenge and category interaction\n",
    "Check whether any combinations of **categories and challenges** are especially difficult."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b12c85a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "sg = SliceGuard()\n",
    "\n",
    "# Explore which min_drop and min_support levels could make sense\n",
    "sg.show(df, [\"challenge\", \"category\"],\n",
    "        \"clip_score\",\n",
    "        \"clip_score\",\n",
    "        return_precomputed_metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5769abae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find issues based on combinations of category and challenge\n",
    "sg = SliceGuard()\n",
    "\n",
    "sg.find_issues(df, [\"challenge\", \"category\"],\n",
    "               \"clip_score\",\n",
    "               \"clip_score\",\n",
    "               return_precomputed_metric,\n",
    "               min_drop=1,\n",
    "               min_support=20)\n",
    "sg.report(spotlight_dtype={\"image\": Image})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc991ca2",
   "metadata": {},
   "source": [
    "# Step 5: Analysis based on prompt embeddings\n",
    "Check whether there are clusters in the CLIP text **embedding space** whose prompts are especially challenging."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e5d89d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sg = SliceGuard()\n",
    "\n",
    "# Explore which min_drop and min_support levels could make sense\n",
    "sg.show(df, [\"clip_text_embedding\"],\n",
    "        \"clip_score\",\n",
    "        \"clip_score\",\n",
    "        return_precomputed_metric,\n",
    "        precomputed_embeddings={\"clip_text_embedding\": clip_text_embeddings})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87d8d6ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find issues based on clusters in the CLIP text embedding space\n",
    "sg = SliceGuard()\n",
    "\n",
    "issues = sg.find_issues(df, [\"clip_text_embedding\"],\n",
    "                        \"clip_score\",\n",
    "                        \"clip_score\",\n",
    "                        return_precomputed_metric,\n",
    "                        min_support=3,\n",
    "                        min_drop=6,\n",
    "                        precomputed_embeddings={\"clip_text_embedding\": clip_text_embeddings})\n",
    "sg.report(spotlight_dtype={\"image\": Image})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
